fix: graceful fallback when attention backends fail to import #13060
Open
sym-bot wants to merge 2 commits into huggingface:main from
Conversation
## Problem

External attention backends (flash_attn, xformers, sageattention, etc.) may be installed but fail to import at runtime due to ABI mismatches. For example, when `flash_attn` is compiled against PyTorch 2.4 but used with PyTorch 2.8, the import fails with:

```
OSError: .../flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEab
```

The current code uses `importlib.util.find_spec()` to check whether packages exist, but this only verifies that the package is installed, not that it can actually be imported. When the import fails, diffusers crashes instead of falling back to native PyTorch attention.

## Solution

Wrap all external attention backend imports in try-except blocks that catch `ImportError` and `OSError`. On failure:

1. Log a warning message explaining the issue
2. Set the corresponding `_CAN_USE_*` flag to `False`
3. Set the imported functions to `None`

This allows diffusers to gracefully degrade to PyTorch's native SDPA (`scaled_dot_product_attention`) instead of crashing. A minimal sketch of the pattern is shown below the description.

## Affected backends

- flash_attn (Flash Attention)
- flash_attn_3 (Flash Attention 3)
- aiter (AMD Instinct)
- sageattention (SageAttention)
- flex_attention (PyTorch Flex Attention)
- torch_npu (Huawei NPU)
- torch_xla (TPU/XLA)
- xformers (Meta xFormers)

## Testing

Tested with PyTorch 2.8.0 and flash_attn 2.7.4.post1 (compiled for PyTorch 2.4). Before: `from diffusers import ...` crashes with an undefined symbol error. After: a warning is logged and native attention is used.
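Below is a minimal sketch of the guarded-import pattern described above, using flash_attn as the example backend. The flag `_CAN_USE_FLASH_ATTN` and the `_flash_attention` dispatcher are illustrative stand-ins for the `_CAN_USE_*` flags and dispatch code mentioned in the description; the actual module layout in diffusers differs, and tensor-layout handling between the two code paths is omitted.

```python
import torch
from diffusers.utils import logging

logger = logging.get_logger(__name__)

# importlib.util.find_spec() only proves the package is installed; the import
# itself can still fail (e.g. an ABI mismatch surfacing as an OSError about an
# undefined symbol), so the import is guarded as well.
try:
    from flash_attn import flash_attn_func

    _CAN_USE_FLASH_ATTN = True
except (ImportError, OSError) as e:
    logger.warning(f"flash_attn is installed but could not be imported ({e}); falling back to PyTorch SDPA.")
    flash_attn_func = None
    _CAN_USE_FLASH_ATTN = False


def _flash_attention(query, key, value):
    # Illustrative dispatch: use flash_attn when it imported cleanly, otherwise
    # fall back to PyTorch's native scaled_dot_product_attention.
    if _CAN_USE_FLASH_ATTN:
        return flash_attn_func(query, key, value)
    return torch.nn.functional.scaled_dot_product_attention(query, key, value)
```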
DN6 (Collaborator) reviewed Feb 16, 2026 and left a comment:
LGTM 👍🏽 Just some minor requests.
```python
except (ImportError, OSError) as e:
    # Handle ABI mismatch or other import failures gracefully.
    # This can happen when flash_attn was compiled against a different PyTorch version.
    _flash_attn_logger = get_logger(__name__)
```
DN6 (Collaborator)
I think we can just add a single logger at the beginning of the file and reuse it instead of creating a dedicated one for each backend.
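For example (a sketch only; the second backend is illustrative), a single module-level logger can be shared by every guard in the file:

```python
from diffusers.utils import logging

# One module-level logger, reused by every backend guard below instead of a
# dedicated logger per backend.
logger = logging.get_logger(__name__)

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
except (ImportError, OSError) as e:
    logger.warning(f"flash_attn could not be imported: {e}")
    flash_attn_func = flash_attn_varlen_func = None

try:
    import xformers.ops as xops
except (ImportError, OSError) as e:
    logger.warning(f"xformers could not be imported: {e}")
    xops = None
```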
```python
try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
except (ImportError, OSError) as e:
```
DN6 (Collaborator)
I think we can include RuntimeError in the exceptions list as well.
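With that suggestion applied, the guard would look roughly like this (a sketch, not the final diff):

```python
from diffusers.utils import logging

logger = logging.get_logger(__name__)

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
except (ImportError, OSError, RuntimeError) as e:
    # RuntimeError covers backends whose import-time initialization fails for
    # reasons other than a missing module or an undefined-symbol OSError.
    logger.warning(f"flash_attn could not be imported: {e}")
    flash_attn_func = flash_attn_varlen_func = None
```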
- Move logger to module level instead of creating per-backend loggers
- Add RuntimeError to exception list alongside ImportError and OSError

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>